Школа глубокого обучения ФПМИ МФТИ

Домашнее задание. Сегментация изображений

В этом задании вам предстоит решить задачу сегментации медицинских снимков. Домашнее задание можно разделить на следующие части:

  • Построй свой первый бейзлайн! [6]
    • BCE Loss [2]
    • SegNet [2]
    • Train [1]
    • Test [1]
  • Мир других лоссов! [2]
    • Dice Loss [1]
    • Focal Loss [1]
    • BONUS: лосс из статьи [5]
  • Новая модель! [2]
    • UNet [2]

Максимальный балл: 10 баллов.

Также для студентов желающих еще более углубиться в задачу предлагается решить бонусное задание, которое даст дополнительные 5 баллов. BONUS задание необязательное.

Шаг 1. Загрузка и подготовка данных¶

  1. Для начала мы скачаем датасет: ADDI project.
No description has been provided for this imageNo description has been provided for this image
  1. Разархивируем .rar файл.
  2. Обратите внимание, что папка PH2 Dataset images должна лежать там же где и ipynb notebook.

Это фотографии двух типов поражений кожи: меланома и родинки. В данном задании мы не будем заниматься их классификацией, а будем сегментировать их.

In [ ]:
!gdown 1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql
Downloading...
From (original): https://drive.google.com/uc?id=1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql
From (redirected): https://drive.google.com/uc?id=1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql&confirm=t&uuid=2085d32f-a224-4c84-b7a5-6dfb511a68d1
To: /content/PH2Dataset.rar
100% 162M/162M [00:01<00:00, 125MB/s]
In [ ]:
get_ipython().system_raw("unrar x PH2Dataset.rar")

Стуктура датасета у нас следующая:

IMD_002/
    IMD002_Dermoscopic_Image/
        IMD002.bmp
    IMD002_lesion/
        IMD002_lesion.bmp
    IMD002_roi/
        ...
IMD_003/
    ...
    ...

 

Здесь X.bmp — изображение, которое нужно сегментировать, X_lesion.bmp — результат сегментации.

Для загрузки датасета можно использовать skimage: skimage.io.imread()

In [ ]:
images = []
lesions = []
from skimage.io import imread
import os
root = 'PH2Dataset'

for root, dirs, files in os.walk(os.path.join(root, 'PH2 Dataset images')):
    if root.endswith('_Dermoscopic_Image'):
        images.append(imread(os.path.join(root, files[0])))
    if root.endswith('_lesion'):
        lesions.append(imread(os.path.join(root, files[0])))

Изображения имеют разные размеры. Давайте изменим их размер на $256\times256 $ пикселей. Для изменения размера изображений можно использовать skimage.transform.resize(). Эта функция также автоматически нормализует изображения в диапазоне $[0,1]$.

In [ ]:
from skimage.transform import resize
size = (256, 256)
X = [resize(x, size, mode='constant', anti_aliasing=True,) for x in images]
Y = [resize(y, size, mode='constant', anti_aliasing=False) > 0.5 for y in lesions]
In [ ]:
import numpy as np
X = np.array(X, np.float32)
Y = np.array(Y, np.float32)
print(f'Loaded {len(X)} images')
Loaded 200 images

Чтобы убедиться, что все корректно, мы нарисуем несколько изображений

In [ ]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

plt.figure(figsize=(18, 6))
for i in range(6):
    plt.subplot(2, 6, i+1)
    plt.axis("off")
    plt.imshow(X[i])

    plt.subplot(2, 6, i+7)
    plt.axis("off")
    plt.imshow(Y[i])
plt.show();
No description has been provided for this image

Разделим наши 200 картинок на 100/50/50 для обучения, валидации и теста соответственно

In [ ]:
ix = np.random.choice(len(X), len(X), False)
tr, val, ts = np.split(ix, [100, 150])
In [ ]:
print(len(tr), len(val), len(ts))
100 50 50

PyTorch DataLoader¶

In [ ]:
from torch.utils.data import DataLoader
batch_size = 25
train_dataloader = DataLoader(list(zip(np.rollaxis(X[tr], 3, 1), Y[tr, np.newaxis])),
                     batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(list(zip(np.rollaxis(X[val], 3, 1), Y[val, np.newaxis])),
                      batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(list(zip(np.rollaxis(X[ts], 3, 1), Y[ts, np.newaxis])),
                     batch_size=batch_size, shuffle=False)
In [ ]:
loaders = {'train':train_dataloader, 'val': valid_dataloader}
In [ ]:
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
cuda
In [ ]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
In [ ]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Шаг 2. Метрика качества модели¶

IoU (intersection over union)¶

В данном разделе предлагается использовать следующую метрику для оценки качества:

$I o U=\frac{\text {target } \cap \text { prediction }}{\text {target } \cup{prediction }}$

Пересечение (A ∩ B) состоит из пикселей, найденных как в маске предсказания, так и в основной маске истины, тогда как объединение (A ∪ B) просто состоит из всех пикселей, найденных либо в маске предсказания, либо в целевой маске.

Что будет являться пересением и объединением в задаче сегментации?

Давайте разберем следующий пример:

No description has been provided for this image

In [ ]:
!pip install -q torchmetrics
In [ ]:
from torchmetrics import JaccardIndex

iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)

Задания: Построй свой первый бейзлайн!¶

Итак, загрузка файлов, код датасета и даталоадера написана за вас. Метрика IoU написана за вас! Вам остается написать лосс, модель и функции обучения и теста модели.

  • Построй свой первый бейзлайн! [6]
    • BCE Loss [2]
    • SegNet [2]
    • Train [1]
    • Test [1]

Шаг 3. Loss функция - BCE [2 балла]¶

Популярным лоссом для бинарной сегментации является бинарная кросс-энтропия, которая задается следующим образом:

$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] \space [1]$$

где $y$ это таргет желаемого результата и $\hat y$ является выходом модели. $\sigma$ - это логистическая функция, который преобразует действительное число $\mathbb R$ в вероятность $[0,1]$.

Однако эта потеря страдает от проблем численной нестабильности. Самое главное, что $\lim_{x\rightarrow0}\log(x)=\infty$ приводит к неустойчивости в процессе оптимизации. Рекомендуется посмотреть следующее упрощение. Эта функция эквивалентна первой и не так подвержена численной неустойчивости:

$$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right) \space [2]$$

Вывод численно стабильной формулы BCE лосса [1 балл]¶

Выведите из формулы [1] формулу [2]:

$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] \space [1]$$

$$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right) \space [2]$$

Не забываем, что здесь $\hat y_i$ - это логиты сети, не вероятности и не лейблы.

Ответ:

$$\mathcal \log\sigma(\hat y_i) = \log\frac{1}{1+e^{-\hat y_i}}=\log1 - \log(1+e^{-\hat y_i}) = - \log(1+e^{-\hat y_i})$$

$$\mathcal \log(1 - \sigma(\hat y_i)) = \log(1 - \frac{1}{1+e^{-\hat y_i}}) = \log\frac{e^{-\hat y_i}}{1+e^{-\hat y_i}} = \log(e^{-\hat y_i}) - \log(1+e^{-\hat y_i}) = {-\hat y_i} - \log(1+e^{-\hat y_i})$$

$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] = - \sum_i \left[y_i( - \log(1+e^{-\hat y_i})) + (1-y_i)({-\hat y_i} - \log(1+e^{-\hat y_i}))\right]=- \sum_i \left[-y_i \log(1+e^{-\hat y_i}) -{\hat y_i} - \log(1+e^{-\hat y_i}) + y_i{\hat y_i} + y_i \log(1+e^{-\hat y_i})\right] = \sum_i \left[{\hat y_i} - y_i{\hat y_i} + \log(1+e^{-\hat y_i}) \right]$$

Реализуйте в коде оба варианта лосса [1 балл]¶

Реализуйте следующие функции:

  • bce_true() - честная прямая реализация лосса с формулой $$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$
  • bce_loss() - реализация формулы, которую мы вывели $$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right).$$

И сравните результаты функций с реализацией Pytorch:

  • bce_torch()
  • bce_torch_with_logits()
In [ ]:
import torch.nn.functional as F
import torch.nn as nn
In [ ]:
bce_torch = nn.BCELoss(reduction='sum') # (sigmoid(y_pred), y_real)
bce_torch_with_logits = nn.BCEWithLogitsLoss(reduction='sum')
In [ ]:
def bce_loss(y_pred, y_real):
  return torch.sum(y_pred - y_pred*y_real + torch.log(1 + torch.exp(-y_pred)))

def bce_true(y_pred, y_real):
  y = y_real
  p = 1 / (1 + torch.exp(-y_pred))
  return -torch.sum(y * torch.log(p) + (1 - y) * torch.log(1 - p))

Проверим корректность работы на простом примере

In [ ]:
y_pred = torch.randn(3, 2, requires_grad=False)
y_true = torch.rand(3, 2, requires_grad=False)

print(f'BCE loss from scratch bce_loss             = {bce_loss(y_pred, y_true)}')
print(f'BCE loss честно посчитанный                = {bce_true(y_pred, y_true)}')
print(f'BCE loss from torch bce_torch              = {bce_torch(torch.sigmoid(y_pred), y_true)}')
print(f'BCE loss from torch with logits bce_torch  = {bce_torch_with_logits(y_pred, y_true)}')
BCE loss from scratch bce_loss             = 4.616026401519775
BCE loss честно посчитанный                = 4.616025924682617
BCE loss from torch bce_torch              = 4.616025924682617
BCE loss from torch with logits bce_torch  = 4.616025924682617

Инструкции assert в Python — это булевы выражения, которые проверяют, является ли условие истинным (True). Внизу в коде мы проверяем функция bce_loss() выдает тот же результат, что и функция из Pytorch или нет. Если равенства не будет, что будет означать, что результаты функций не совпадают, а значит вы неправильно реализовали фукнцию bce_loss(), assert возвратит ошибку.

Функция numpy.isclose() используется для сравнения двух чисел с учётом допустимой погрешности. Она особенно полезна при работе с числами с плавающей точкой, где точное сравнение может быть проблематичным из-за ограничений представления таких чисел в компьютере.

Как она работает?

numpy.isclose(a, b, rtol=1e-05, atol=1e-08) принимает два числа (a и b) и сравнивает их, учитывая относительную и абсолютную погрешность. Если разница между двумя числами меньше заданного порога, функция возвращает True, иначе — False.

Параметры:

rtol: Относительная погрешность (по умолчанию 1e-05). Используется для определения разницы относительно большего значения.
atol: Абсолютная погрешность (по умолчанию 1e-08). Определяет минимальную разницу, которую следует учитывать.

Мы будем использовать assert и numpy.isclose() для проверки корректности нашего кода.

In [ ]:
assert np.isclose(bce_loss(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true))
assert np.isclose(bce_loss(y_pred, y_true), bce_torch_with_logits(y_pred, y_true))
assert np.isclose(bce_true(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true))
assert np.isclose(bce_true(y_pred, y_true), bce_torch_with_logits(y_pred, y_true))

Давайте теперь посчитаем на простом примере, но с теми же размерностями, что и в датасете

In [ ]:
y_pred = torch.randn((2, 1, 3, 3), requires_grad=False)
y_true = torch.randint(0, 2, (2, 1, 3, 3))

print(f'BCE loss from scratch bce_loss            = {bce_loss(y_pred, y_true)}')
print(f'BCE loss честно посчитанный               = {bce_true(y_pred, y_true)}')
print(f'BCE loss from torch bce_torch             = {bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float))}')
print(f'BCE loss from torch with logits bce_torch = {bce_torch_with_logits(y_pred, y_true.to(torch.float))}')
BCE loss from scratch bce_loss            = 14.737800598144531
BCE loss честно посчитанный               = 14.737801551818848
BCE loss from torch bce_torch             = 14.737801551818848
BCE loss from torch with logits bce_torch = 14.737800598144531
In [ ]:
assert np.isclose(bce_loss(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float)))
assert np.isclose(bce_loss(y_pred, y_true), bce_torch_with_logits(y_pred, y_true.to(torch.float)))
assert np.isclose(bce_true(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float)))
assert np.isclose(bce_true(y_pred, y_true), bce_torch_with_logits(y_pred, y_true.to(torch.float)))

Давайте посчитаем на реальных логитах и сегментационной маске:

In [ ]:
!gdown --folder 1EX0RW1TRQVkLmR1h6miCQqyhYPFyg28M
Retrieving folder contents
Processing file 1--WxvBdpMn_NOmYPf3a4au8MHzfx5baC labels.pt
Processing file 1-0A7_CS_vKiSCkgIDJ4joThCEcFedA3I logits.pt
Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=1--WxvBdpMn_NOmYPf3a4au8MHzfx5baC
To: /content/for_asserts/labels.pt
100% 1.18k/1.18k [00:00<00:00, 5.16MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-0A7_CS_vKiSCkgIDJ4joThCEcFedA3I
To: /content/for_asserts/logits.pt
100% 1.18k/1.18k [00:00<00:00, 519kB/s]
Download completed
In [ ]:
path_to_dummy_samples = '/content/for_asserts'
dummpy_sample = {'logits': torch.load(f'{path_to_dummy_samples}/logits.pt'),
                 'labels': torch.load(f'{path_to_dummy_samples}/labels.pt')}
dummpy_sample['labels'] = dummpy_sample['labels'].to(DEVICE)
dummpy_sample['logits'] = dummpy_sample['logits'].to(DEVICE)
In [ ]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize= (10,3*10))

ax1.imshow(dummpy_sample['labels'].squeeze(1)[0].cpu())
ax1.set_title("Original")

ax2.imshow(dummpy_sample['logits'].sigmoid().squeeze(1)[0].cpu())
for (j,i),label in np.ndenumerate(dummpy_sample['logits'].sigmoid().squeeze(1)[0].cpu()):
    if label < 0.5:
        color = 'white'
    else:
        color = 'black'
    ax2.text(i,j,round(label,3), color=color, ha='center',va='center')

ax2.set_title("Predicted Probabilities")

ax3.imshow((dummpy_sample['logits'].sigmoid() > 0.5).squeeze(1)[0].cpu())
ax3.set_title("Predicted Mask")
plt.show()
No description has been provided for this image

Проверяем на данном примере:

In [ ]:
bce_loss_score = bce_loss(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu())
bce_true_score = bce_true(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu())
bce_torch_score = bce_torch(torch.sigmoid(dummpy_sample['logits'].cpu()), dummpy_sample['labels'].cpu().float())
bce_torch_with_logits_score = bce_torch_with_logits(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu().float())
assert np.isclose(bce_loss_score, bce_torch_score)
assert np.isclose(bce_loss_score, bce_torch_with_logits_score)
assert np.isclose(bce_true_score, bce_torch_score)
assert np.isclose(bce_true_score, bce_torch_with_logits_score)

Шаг 4. Модель SegNet [2 балла]¶

Ваше задание здесь состоит в том, чтобы реализовать SegNet архитектуру.

No description has been provided for this image
  • Badrinarayanan, V., Kendall, A., & Cipolla, R. (2015). SegNet: A deep convolutional encoder-decoder architecture for image segmentation
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from time import time

from matplotlib import rcParams
rcParams['figure.figsize'] = (15,4)

Внимательно посмотрите из чего состоит модель и для чего выбраны те или иные блоки. Для этого скачаем и изучим feature extractor VGG-16, который лежит в основе SegNet.

In [ ]:
model_vgg16 = models.vgg16(weights = models.VGG16_Weights.IMAGENET1K_V1)
In [ ]:
model_vgg16
Out[ ]:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Feature extractor VGG-16 состоит из 5 блоков:

  • два блока со структурой: Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d
  • три блока со структурой: Conv2d -> ReLU -> Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d

В первом блоке - на входе три канала (по числу каналов в изображениях), которые конволюционный слой преобразует в 64 канала.

Во втором, третьем и четвертом блоках первый конволюционный слой удваивает количество каналов, а последующие конволюционные слои не меняют количество каналов.

В последнем блоке число каналов от слоя к слою не меняется.

Теперь напишем код одного блока энкодера нашей модели SegNet.

In [ ]:
# Параметрами блока будут:
# - количество каналов на входе
# - количество каналов на выходе
# - глубина блока (2 или 3, по количеству конволюционных слоев)
# - kernel_size и padding
#
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, depth, kernel_size = 3, padding = 1):
        super(EncoderBlock, self).__init__() # инициируем экземляр класса, наследующего от nn.Module
        self.layers = nn.ModuleList() # в self.layers будем добавлять слои блока
        # дальше реализуем то, что на картинке выше обозначено Conv + Batch Normalization + ReLU
        self.layers.append(nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, padding = padding))
        self.layers.append(nn.BatchNorm2d(out_channels))
        self.layers.append(nn.ReLU(inplace=True))

        # цикл for помогает использовать один код для блоков как с глубиной 2, так и с глубиной 3
        for i in range(depth-1):
            self.layers.append(nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = kernel_size, padding = padding))
            self.layers.append(nn.BatchNorm2d(out_channels))
            self.layers.append(nn.ReLU(inplace=True))

        self.maxpooling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) #добавляем MaxPool с индексами для последующего Unpooling

    # Обратите внимание: на вход метод forward() получает карту признаков (х),
    # а возвращает карту признаков и индексы для последующего Unpooling
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        size = x.size()
        x, indices = self.maxpooling(x)
        return x, indices, size

По аналогии напишите код одного блока декодера.

К карте признаков на входе каждого блока примеяется nn.MaxUnpool2d с индексами из симметричного блока энкодера. Затем повторяется связка Conv + Batch Normalization + ReLU. Количество каналов меняется зеркально блокам энкодера:

  • в первом блоке декодера количество каналов не меняется
  • во 2-4 блоках декодера количество каналов уменьшается в 2 раза после прохождения последнего конволюционного слоя
  • на выходе из последнего блока декодера 1 канал

Обратите внимание, что после последней конволюции последнего блока декодера не применяется батч-нормализация и функция активации.

In [ ]:
class DecoderBlock(nn.Module):
  def __init__(self, in_channels, out_channels, depth, kernel_size = 3, padding = 1):
    super().__init__()

    self.layers = nn.ModuleList()
    for _ in range(depth):
      self.layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding))
      self.layers.append(nn.BatchNorm2d(out_channels))
      self.layers.append(nn.ReLU(inplace=True))
      in_channels = out_channels

    self.maxunpooling = nn.MaxUnpool2d(kernel_size=2, stride=2)

  def forward(self, x, indices, output_size):
    x = self.maxunpooling(x, indices, output_size)

    for layer in self.layers:
      x = layer(x)
    return x

Соединим блоки энкодера и декодера в модель SegNet:

In [ ]:
class SegNet(nn.Module):
    def __init__(self, in_channels=3, out_channels = 1, num_features = 64) -> None:
        super(SegNet, self).__init__()

        # Encoder
        self.encoder0 = EncoderBlock(in_channels, num_features, depth=2)
        self.encoder1 = EncoderBlock(num_features, num_features * 2, depth=2)
        self.encoder2 = EncoderBlock(num_features * 2, num_features * 4, depth=3)
        self.encoder3 = EncoderBlock(num_features * 4, num_features * 8, depth=3)

        # Encoder bottleneck - количество каналов на входе и на выходе одинаково
        self.encoder4 = EncoderBlock(num_features * 8, num_features * 8, depth=3)

        # Decoder bottleneck
        self.decoder0 = DecoderBlock(num_features * 8, num_features * 8, depth=3)

        # Decoder
        self.decoder1 = DecoderBlock(num_features * 8, num_features * 4, depth=3)
        self.decoder2 = DecoderBlock(num_features * 4, num_features * 2, depth=3)
        self.decoder3 = DecoderBlock(num_features * 2, num_features, depth=2)
        self.decoder4 = DecoderBlock(num_features, num_features, depth=2)
        self.final = nn.Conv2d(num_features, out_channels, kernel_size=1)

    def forward(self, x):
        # encoder
        e0, ind0, size0 = self.encoder0(x)
        e1, ind1, size1 = self.encoder1(e0)
        e2, ind2, size2 = self.encoder2(e1)
        e3, ind3, size3 = self.encoder3(e2)
        e4, ind4, size4 = self.encoder4(e3)

        # Decoder
        d0 = self.decoder0(e4, ind4, size4)
        d1 = self.decoder1(d0, ind3, size3)
        d2 = self.decoder2(d1, ind2, size2)
        d3 = self.decoder3(d2, ind1, size1)
        d4 = self.decoder4(d3, ind0, size0)
        output = self.final(d4)
        return output  # no activation

Шаг 5. Тренировка модели [1 балл]¶

Напишите функции для обучения модели.

In [ ]:
from tqdm.notebook import tqdm
In [ ]:
def fit_one_epoch(model, train_dataloader, optimizer, loss_func):
  '''
  args:
    model - модель для обучения
    train_dataloader - loader с выборкой для обучения модели
    optimizer - оптимизатор, взятый из модуля `torch.optim`
    loss_func - функция потерь, взятая из модуля `torch.nn`

  функция возвращает метрику accuracy по эпохе на данных из train_dataloader
  '''

  model.train()
  avg_loss = 0
  visualized = False

  for X_batch, y_batch in tqdm(train_dataloader):
    X_batch = X_batch.to(DEVICE)
    y_batch = y_batch.to(DEVICE)

    optimizer.zero_grad()
    outp = model(X_batch)

    prob = torch.sigmoid(outp)
    y_pred = (prob > 0.5).long()
    if not visualized:
      visualized = True
      visualize(X_batch, y_batch, y_pred)

    loss = loss_func(outp, y_batch)
    loss.backward()
    optimizer.step()

    avg_loss += loss.item()

  return avg_loss / len(train_dataloader)
In [ ]:
def eval_one_epoch(model, val_dataloader, loss_func):
  '''
  args:
    model - модель для обучения
    val_dataloader - loader с валидационной/тестовой выборкой
  '''
  iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)

  model.eval()
  avg_loss = 0
  avg_iou = 0

  visualized = False

  with torch.no_grad():
    for X_batch, y_batch in tqdm(val_dataloader):
      X_batch = X_batch.to(DEVICE)
      y_batch = y_batch.to(DEVICE)

      outp = model(X_batch)
      prob = torch.sigmoid(outp)
      y_pred = (prob > 0.5).long()

      loss = loss_func(outp, y_batch)
      iou = iou_score(y_pred, y_batch)

      avg_loss += loss.item()
      avg_iou += iou

      if not visualized:
        visualized = True
        visualize(X_batch, y_batch, y_pred)

    avg_loss = avg_loss / len(val_dataloader)
    avg_iou = avg_iou / len(val_dataloader)

  return avg_loss, avg_iou
In [ ]:
def visualize(X_batch, y_batch, pred, n=1):
  batch_size = X_batch.shape[0]

  plt.figure(figsize=(10, 3*n))
  for i in range(n):
    img = X_batch[i].permute(1,2,0).cpu().numpy()
    true_mask = y_batch[i].cpu().squeeze().numpy()
    pred_mask = pred[i].squeeze().cpu().numpy()

    plt.subplot(n, 3, i*3 + 1)
    plt.title("Image")
    plt.imshow(img)
    plt.axis("off")

    plt.subplot(n, 3, i*3 + 2)
    plt.title("GT mask")
    plt.imshow(true_mask, cmap="nipy_spectral")
    plt.axis("off")

    plt.subplot(n, 3, i*3 + 3)
    plt.title("Pred mask")
    plt.imshow(pred_mask, cmap="nipy_spectral")
    plt.axis("off")

  plt.show()
In [ ]:
def train_func(model, num_epochs, dataloaders, optimizer, loss_func):
  '''
  args:
    model - модель для обучения
    num_epochs - количество эпох
    dataloaders - словарь loader'ов с обучающей и валидационной выборками
    optimizer - оптимизатор, взятый из модуля `torch.optim`
    loss_func - функция потерь, взятая из модуля `torch.nn`

  функция возвращает loss на обучающей и валидационной выборках на каждой эпохе, а также метрику IoU на валидационной выборке
  '''
  model = model.to(DEVICE)
  score = {"train_loss": [], "val_loss": [], "val_iou": []}
  for epoch in range(num_epochs):
    print(f"\nEpoch: {epoch+1}")

    loss_train = fit_one_epoch(model = model, train_dataloader = dataloaders['train'], optimizer = optimizer, loss_func = loss_func)
    print(f"Loss train: {loss_train}\n")

    loss_val, iou_val = eval_one_epoch(model = model, val_dataloader = dataloaders['val'], loss_func=loss_func)
    print(f"Loss valid: {loss_val}\n")
    print(f"IoU valid: {iou_val}\n")


    score['train_loss'].append(loss_train)
    score['val_loss'].append(loss_val)
    score['val_iou'].append(iou_val)
  return model, score

Обучите модель SegNet. В качестве оптимайзера можно взять Adam.

In [ ]:
model_baseline = SegNet()
In [ ]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model_baseline.parameters(), lr = 1e-3)
In [ ]:
model_baseline, score_baseline = train_func(model_baseline, 50, loaders, optimizer, criterion)
Epoch: 1
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.7088902145624161

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6806586980819702

IoU valid: 0.0


Epoch: 2
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.6382251232862473

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6685355007648468

IoU valid: 0.0


Epoch: 3
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.500316396355629

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6420924067497253

IoU valid: 0.38382676243782043


Epoch: 4
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.4157711789011955

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6269842386245728

IoU valid: 0.42938777804374695


Epoch: 5
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.37223953753709793

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6746330857276917

IoU valid: 0.3734997510910034


Epoch: 6
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.3497515767812729

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6309003829956055

IoU valid: 0.4095335602760315


Epoch: 7
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.3247246891260147

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6306696534156799

IoU valid: 0.43107739090919495


Epoch: 8
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.3140593394637108

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.704652726650238

IoU valid: 0.407431960105896


Epoch: 9
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.3083975985646248

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4111880213022232

IoU valid: 0.5894871950149536


Epoch: 10
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.2947194203734398

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.5216855704784393

IoU valid: 0.5018594264984131


Epoch: 11
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.28925783932209015

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.30654121935367584

IoU valid: 0.6960251331329346


Epoch: 12
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.28432005643844604

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4548138678073883

IoU valid: 0.5753812789916992


Epoch: 13
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.2550861984491348

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4729755222797394

IoU valid: 0.5680891275405884


Epoch: 14
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.24712393060326576

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4965546280145645

IoU valid: 0.5238682627677917


Epoch: 15
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.24329785630106926

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.5635707378387451

IoU valid: 0.5173154473304749


Epoch: 16
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.23090828210115433

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4164176285266876

IoU valid: 0.5875464081764221


Epoch: 17
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.2251473069190979

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.40885478258132935

IoU valid: 0.6005829572677612


Epoch: 18
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.2389150969684124

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.4225718677043915

IoU valid: 0.6117110848426819


Epoch: 19
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.21030806750059128

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.5596248060464859

IoU valid: 0.5397874712944031


Epoch: 20
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.23217927664518356

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.26289648562669754

IoU valid: 0.735488772392273


Epoch: 21
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.24586381763219833

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.6872997879981995

IoU valid: 0.4611112177371979


Epoch: 22
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.22472476214170456

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.34815242886543274

IoU valid: 0.6736955046653748


Epoch: 23
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.23693016543984413

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.22739911824464798

IoU valid: 0.7620818614959717


Epoch: 24
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.23085859790444374

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.3239000588655472

IoU valid: 0.6869162321090698


Epoch: 25
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.22607507184147835

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.5454607158899307

IoU valid: 0.5920796990394592


Epoch: 26
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.21723804995417595

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.3899311274290085

IoU valid: 0.670615553855896


Epoch: 27
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.19333281368017197

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2741902843117714

IoU valid: 0.7241814732551575


Epoch: 28
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.21157973632216454

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.38126857578754425

IoU valid: 0.6279032230377197


Epoch: 29
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.18963951990008354

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.27555491775274277

IoU valid: 0.7111327052116394


Epoch: 30
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.18125222250819206

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.17716316878795624

IoU valid: 0.8048577904701233


Epoch: 31
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.18316393345594406

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.27888912707567215

IoU valid: 0.7068072557449341


Epoch: 32
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.18919847905635834

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2049947753548622

IoU valid: 0.7776749730110168


Epoch: 33
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.16759219393134117

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.34727734327316284

IoU valid: 0.6685512065887451


Epoch: 34
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.19003837928175926

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2719089090824127

IoU valid: 0.7174249887466431


Epoch: 35
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.17706162855029106

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.1988501325249672

IoU valid: 0.7730058431625366


Epoch: 36
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.16342436894774437

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.1810378059744835

IoU valid: 0.8064121007919312


Epoch: 37
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1796884685754776

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.21672917157411575

IoU valid: 0.7849995493888855


Epoch: 38
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.16774186491966248

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.21554523706436157

IoU valid: 0.7795153856277466


Epoch: 39
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1882936768233776

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.22181197255849838

IoU valid: 0.7836636304855347


Epoch: 40
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.16846414655447006

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2256392240524292

IoU valid: 0.7838817834854126


Epoch: 41
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.17166605219244957

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.3779662996530533

IoU valid: 0.6729598045349121


Epoch: 42
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.16688699647784233

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.24135907739400864

IoU valid: 0.7575369477272034


Epoch: 43
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1529532317072153

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2166372314095497

IoU valid: 0.7783081531524658


Epoch: 44
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.13592077419161797

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.22347378730773926

IoU valid: 0.7527279853820801


Epoch: 45
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1690601073205471

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.31689217686653137

IoU valid: 0.6695342063903809


Epoch: 46
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.13234772719442844

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.334345281124115

IoU valid: 0.6826684474945068


Epoch: 47
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1293657124042511

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.21620211750268936

IoU valid: 0.7776235342025757


Epoch: 48
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.14344285055994987

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.24861302226781845

IoU valid: 0.734126091003418


Epoch: 49
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.14622491039335728

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.1953466534614563

IoU valid: 0.7786290645599365


Epoch: 50
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
Loss train: 0.1356885675340891

  0%|          | 0/2 [00:00<?, ?it/s]
No description has been provided for this image
Loss valid: 0.2108900099992752

IoU valid: 0.7840803861618042

Шаг 6. Инференс [1 балл]¶

После обучения модели напишите функцию теста, воспользуйтесь лучшим чекпоинтом и протестируйте работу модели на тестовой выборке.

In [ ]:
def test(model, test_dataloader):
  model.eval()

  iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)
  avg_iou = 0
  with torch.no_grad():
    for X_batch, Y_batch in test_dataloader:
      X_batch = X_batch.to(DEVICE)
      Y_batch = Y_batch.to(DEVICE)

      outp = model(X_batch)
      prob = torch.sigmoid(outp)
      y_pred = (prob > 0.5).long()
      avg_iou += iou_score(y_pred, Y_batch)

  return avg_iou / len(test_dataloader)
In [ ]:
test_score = test(model_baseline, test_dataloader)
In [ ]:
test_score.item()
Out[ ]:
0.8248778581619263

Задания: Мир других лоссов!¶

Пробуем другие функции потерь [2 балла]¶

В данном разделе вам потребуется имплементировать две функции потерь: DICE и Focal loss.

Dice Loss¶

1. Dice coefficient: Учитывая две маски $X$ и $Y$, общая метрика для измерения расстояния между этими двумя масками задается следующим образом:

$$D(X,Y)=\frac{2|X\cap Y|}{|X|+|Y|}$$

В терминах матрицы ошибок она будет считаться следующим образом:

$$D(X,Y) = \frac{2TP}{2TP + FP + FN}$$

Эта функция не является дифференцируемой, но это необходимое свойство для градиентного спуска. В данном случае мы можем приблизить его с помощью:

$$\mathcal L_D(X,Y) = 1- D(X, Y)$$

Hints (!):

  1. Не забудьте подумать о численной нестабильности, возникающей в математической формуле при ситуации, когда $\frac{0}{0}$, т.е. вам нужно добавить очень маленькое число, например $\epsilon = 1e^{-8}$, в обе части дроби при подсчете $D(X,Y)$:

$$D(X,Y) = \frac{2TP + ϵ}{2TP + FP + FN + ϵ}$$

  1. Dice метрика(!), не лосс, считается похожим образом как IoU:

    2.1. На вход вам приходят logits, т.е. значения от $-∞$ до $∞$. Их переводим в вероятности от 0 до 1 при помощи функции Sigmoid.

    2.2. Фиксируем порог, например threshold=0.5, и всему что ниже порога ставим значение 0, всему что выше 1. Получаем предсказанную маску из 0 и 1.

    2.3. Считаем TP, FP, FN

    2.4. Считаем DICE метрику по формуле

Вы можете прописать для себя функцию dice_score() и сравнить с результатами работы функции из библиотеки torchmetrics.

  1. Но с метрикой есть проблема, что она не дифференцируема, и если вы захотите просто взять и прописать dice_loss = 1 - dice_score, Pytorch поругается на вас и скажет, что это недифференцируемая метрика. Чтобы посчитать dice_loss делаем следующие шаги:

    3.1. На вход вам приходят logits, т.е. значения от $-∞$ до $∞$. Их переводим в вероятности от 0 до 1 при помощи функции Sigmoid.

    3.2. Здесь нам уже не нужно фиксировать порог, мы просто работаем с вероятностями. Значения вероятностей дифференцируемы и через них будут протекать градиенты.

    3.3. Считаем TP, FP, FN также как и в Dice метрике, только вместо маски, подаем вероятности.

    3.4. Считаем DICE метрику по формуле

    3.5. Считаем лосс как Loss = 1 - DICE

Итак, давайте сначала пропишем dice_score.

In [ ]:
def dice_score(logits: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
  '''
  Это именно метрика, не лосс.
  '''
  eps = 1e-8

  prob = torch.sigmoid(logits)
  preds = (prob > threshold).int()

  TP = (preds * labels).sum()
  FP = (preds * (1 - labels)).sum()
  FN = ((1 - preds) * labels).sum()

  score = (2 * TP + eps) / (2 * TP + FP + FN + eps)

  return score

Проверим на корректность функцию dice_score:

In [ ]:
from torchmetrics.segmentation import DiceScore

dice = DiceScore(num_classes=1, average='micro').to(DEVICE)
dice(dummpy_sample['logits'].sigmoid() > 0.5, dummpy_sample['labels'].int())
/usr/local/lib/python3.12/dist-packages/torchmetrics/utilities/prints.py:43: UserWarning: DiceScore metric currently defaults to `average=micro`, but will change to`average=macro` in the v1.9 release. If you've explicitly set this parameter, you can ignore this warning.
  warnings.warn(*args, **kwargs)
Out[ ]:
tensor(0.6667, device='cuda:0')
In [ ]:
assert dice(dummpy_sample['logits'].sigmoid()>0.5, dummpy_sample['labels'].to(int)) == dice_score(dummpy_sample['logits'], dummpy_sample['labels'])

Давайте теперь пропишем лосс и воспользуемся библиотекой segmentation-models-pytorch, чтобы убедиться в корректности нашей функции.

In [ ]:
def dice_loss(logits: torch.Tensor, labels: torch.Tensor):
  eps = 1e-8

  probs = torch.sigmoid(logits)

  TP = (probs * labels).sum()
  FP = (probs * (1 - labels)).sum()
  FN = ((1 - probs) * labels).sum()

  dice = (2 * TP + eps) / (2 * TP + FP + FN + eps)
  return 1 - dice.mean()

Проверка на корректность:

In [ ]:
# проверьте, что у вас установлена библиотека
!pip install -q segmentation-models-pytorch
In [ ]:
from segmentation_models_pytorch.losses import DiceLoss
dice_loss_torch = DiceLoss(mode='binary')
dice_loss_torch(dummpy_sample['logits'], dummpy_sample['labels'])
Out[ ]:
tensor(0.5756, device='cuda:0')
In [ ]:
dice_loss(dummpy_sample['logits'], dummpy_sample['labels'])
Out[ ]:
tensor(0.5756, device='cuda:0')
In [ ]:
assert dice_loss_torch(dummpy_sample['logits'], dummpy_sample['labels'].to(int)) == dice_loss(dummpy_sample['logits'], dummpy_sample['labels'])

Focal Loss¶

2. Focal loss:

Окей, мы уже с вами умеем делать BCE loss:

$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$

Проблема с этой потерей заключается в том, что она имеет тенденцию приносить пользу классу большинства (фоновому) по отношению к классу меньшинства ( переднему). Поэтому обычно применяются весовые коэффициенты к каждому классу:

$$\mathcal L_{wBCE}(y, \hat y) = -\sum_i \alpha_i\left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$

Традиционно вес $\alpha_i$ определяется как обратная частота класса этого пикселя $i$, так что наблюдения миноритарного класса весят больше по отношению к классу большинства.

Из оригинальной статьи по Focal Loss:

$$p_t = \sigma(\hat y_i)y_i + (1 - \sigma(\hat y_i)) (1-y_i)$$

$$\mathcal L_{focal}(y, \hat y) = (1 - p_t)^{\gamma} \mathcal L_{BCE}(y_i, \hat y_i).$$

$$\mathcal L_{focal}(y, \hat y) = -\sum_i (1 - p_t)^{\gamma} \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$

$$\mathcal L_{focal}(y, \hat y) = -\sum_i (1 - (\sigma(\hat y_i)y_i + (1 - \sigma(\hat y_i)) (1-y_i)))^{\gamma} \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$

In [ ]:
def focal_loss(y_real, y_pred, eps = 1e-6, gamma = 2):
  p = torch.sigmoid(y_pred) * y_real + (1 - torch.sigmoid(y_pred))*(1 - y_real)
  bce_loss = -(y_real * torch.log(torch.sigmoid(y_pred) + eps) + (1 - y_real) * torch.log(1 - torch.sigmoid(y_pred) + eps))
  loss = (1 - p)**gamma * bce_loss
  return loss.sum()

Проверка корректности функции:

In [ ]:
from torchvision.ops import sigmoid_focal_loss
sigmoid_focal_loss(dummpy_sample['logits'], dummpy_sample['labels'], alpha=-1, gamma=2, reduction='sum').item()
Out[ ]:
3.616123676300049
In [ ]:
assert torch.allclose(sigmoid_focal_loss(dummpy_sample['logits'], dummpy_sample['labels'], alpha=-1, gamma=2, reduction='sum'),
                      focal_loss(dummpy_sample['labels'], dummpy_sample['logits'], gamma=2.0),
                      rtol=1e-5,
                      atol=1e-8)

[BONUS] Мир сегментационных лоссов [5 баллов]¶

В данном блоке предлагаем вам написать одну функцию потерь самостоятельно. Для этого необходимо прочитать статью и имплементировать ее, и провести численное сравнение с предыдущими функциями.

  • Physiological Inspired Deep Neural Networks for Emotion Recognition". IEEE Access, 6, 53930-53943.

  • Boundary loss for highly unbalanced segmentation

  • Tversky loss function for image segmentation using 3D fully convolutional deep networks

  • Correlation Maximized Structural Similarity Loss for Semantic Segmentation

  • Topology-Preserving Deep Image Segmentation

Для изучения была выбрана статья Correlation Maximized Structural Similarity Loss for Semantic Segmentation

Основная идея Structural Similarity Loss:

Вместо того, чтобы смотреть на соответствие единичных пикселей, которые в свою очередь игнорируют зависимость между друг другом, этот метод предлагает смотреть, насколько коррелированы разные локальные участки предсказанной карты сегментации и истинной, и уделять внимание позициям, чьи предсказания приводят к низкой степени линейной корреляции.

Для реализации этой идеи авторы предлагают рассмотреть сумму обычной BCE loss и перевзвешинной BCE loss. К перевзвешенной при этом добавляют множитель ошибки - меру структурного сходства. Также для участков с маленькой ошибкой, перевзвешенная BCE loss будет зануляться, как бы предполагая, что на них все достаточно хорошо и ничего не нужно менять.

Для реализации этой функции потерь нужны следующие формулы:

Общая целевая функция:

$$ L_{all}(y, p) = \lambda L_{ce}(y, p) + (1-\lambda) L_{ssl}(y, p) \tag{17} $$

Классическая кросс-энтропия

$$ L_{ce}(y, p) = - \frac{1}{N} \sum_{n=1}^N \sum_{c=1}^C y_{n,c} \log(p_{n,c}) \tag{1} $$

$e$ — это общая абсолютная ошибка между стандартизованными нормализованными результатами истинных значений ($y$) и предсказаний ($p$). Это мера структурного различия (обратная корреляции). $C_4 = 0.01$ — стабилизирующий фактор.

$$ e = \left|\frac{y - \mu_y + C_4}{\sigma_y + C_4} - \frac{p - \mu_p + C_4}{\sigma_p + C_4}\right| \tag{10} $$

Маска для выбора примеров отбрасывает «легкие примеры» (те, для которых $e$ мало), тем самым реализуя стратегию Online Hard Example Mining (OHEM).

$$ f_{n,c} = 1_{\{e_{n,c} > \beta e_{\max}\}} \tag{11} $$

$L_{ssl}$ — это сигмоидальная кросс-энтропия $L_{ce}$, перевзвешенная структурной ошибкой $e$ и умноженная на маску $f_{n,c}$. $e_{n,c}$ используется как постоянный весовой коэффициент.

$$ L_{ssl}(y_{n,c}, p_{n,c}) = e_{n,c} f_{n,c} L_{ce}(y_{n,c}, p_{n,c}) \tag{12} $$

Итоговая функция потерь SSL по мини-батчу $L_{ssl}$ усредняется только по $M$ выбранным «трудным примерам». $$ L_{ssl}(y, p) = \frac{1}{M} \sum_{n=1}^N \sum_{c=1}^C L_{ssl}(y_{n,c}, p_{n,c}) \tag{13} $$

Для подсчета статистик необходимо будет пройтись гауссовым окном по картам сегментации.

Вспомогатльные формулы:

Локальное среднее: $$ \mu_y = \sum_{i=1}^{k^2} w_i y_i \tag{14} $$

Локальная дисперсия: $$ \sigma^2_y = \sum_{i=1}^{k^2} w_i(y_i - \mu_y)^2 = \sum_{i=1}^{k^2} w_i y^2_i - \mu^2_y \tag{15} $$

${w_i}$ - ${i}$ -тое значение в гауссовском окне (ядре)

In [ ]:
from typing import Tuple
import torch
import torch.nn.functional as F
In [ ]:
def make_gaussian_kernel(k: int, sigma: float):
  assert k % 2 == 1, "kernel size must be odd"
  half = k // 2
  xs = torch.arange(-half, half+1, dtype=torch.float32)
  ys = xs.view(-1, 1)
  kernel = torch.exp(-(xs**2 + ys**2) / (2 * sigma**2))
  kernel = kernel / kernel.sum()
  return kernel
In [ ]:
class BinaryStructuralSimilarityLoss(nn.Module):
  def __init__(self, window_size=3, tau=0.1, lambda_ce=0.9, c=0.01, eps=1e-8, gauss_sigma=1.5):
    super().__init__()
    self.k = window_size
    self.tau = tau
    self.lambda_ce = lambda_ce
    self.c = c
    self.eps = eps
    self.gauss_sigma = gauss_sigma

  def __get_kernel(self):
    return make_gaussian_kernel(self.k, self.gauss_sigma)

  def count_abs_structural_error(self, probs: torch.Tensor, labels: torch.Tensor):
    B = probs.shape[0]
    p_unf, y_unf = self.make_patches(probs, labels) # (B, k*k, H*W)

    kernel = self.__get_kernel()

    statistics_p = self.compute_statistics_over_patch(probs, kernel) # (B,1,H,W)
    statistics_y = self.compute_statistics_over_patch(labels, kernel)

    mean_p, sigma_p = self.adjust_dims(statistics_p, B) # (B, k*k, H*W)
    mean_y, sigma_y = self.adjust_dims(statistics_y, B)

    z_p = (p_unf - mean_p + self.c) / (sigma_p + self.c)
    z_y = (y_unf - mean_y + self.c) / (sigma_y + self.c)

    error = torch.abs(z_y - z_p).sum(dim=1)  # (B, H*W)
    return error

  def make_patches(self, probs: torch.Tensor, labels: torch.Tensor):
    pad = self.k // 2
    p_unf = F.unfold(probs, kernel_size=self.k, padding=pad) # (B, k*k, H*W)
    l_unf = F.unfold(labels, kernel_size=self.k, padding=pad) # (B, k*k, H*W)

    return p_unf, l_unf

  def compute_statistics_over_patch(self, labels: torch.Tensor, kernel: torch.Tensor, eps: float = 1e-8):
    pad = self.k // 2
    kernel = kernel.view(1, 1, self.k, self.k).to(labels.device)

    mean = F.conv2d(labels, kernel, padding=pad)  # (B,1,H,W)
    E2 = F.conv2d(labels * labels, kernel, padding=pad)
    var = torch.clamp(E2 - mean * mean, min=0.0)
    sigma = torch.sqrt(var + eps)  # (B,1,H,W)

    return mean, sigma

  def adjust_dims(self, statistics: Tuple[torch.Tensor, torch.Tensor], num_batches: int):
    k = self.k
    B = num_batches
    mean_exp = statistics[0].view(B, 1, -1)  # (B, 1, H*W)
    sigma_exp = statistics[1].view(B, 1, -1)  # (B, 1, H*W)
    mean_exp = mean_exp.repeat(1, k*k, 1)  # (B, k*k, H*W)
    sigma_exp = sigma_exp.repeat(1, k*k, 1)
    return mean_exp, sigma_exp

  def forward(self, logits: torch.Tensor, labels: torch.Tensor):
    probs = torch.sigmoid(logits)
    error = self.count_abs_structural_error(probs, labels)
    indicator = (error > self.tau * error.max(dim=1, keepdim=True)[0]).float()

    bce_full = F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')

    bce_pixel = F.binary_cross_entropy_with_logits(logits, labels, reduction='none').view(probs.shape[0], -1) # (B, H*W)
    L_ssl_pixel = error * indicator * bce_pixel # (B, H*W)

    M = indicator.sum(dim=1).clamp(min=1) # (B,)
    L_ssl_batch = (L_ssl_pixel.sum(dim=1) / M) # (B,)
    L_ssl = L_ssl_batch.mean()

    L_all = self.lambda_ce * bce_full + (1 - self.lambda_ce) * L_ssl
    return L_all
In [ ]:
ssloss = BinaryStructuralSimilarityLoss()
ssloss(dummpy_sample['logits'], dummpy_sample['labels'])
Out[ ]:
tensor(0.5462, device='cuda:0')

Проведем численное сравнение с ранее использованными лоссами:

  1. Рандомный пример
In [ ]:
logits = torch.tensor([[[[
    -1.,  1.,  2.,  1., -1.,
     1., -2., -3., -2.,  1.,
     2., -3., -4., -3.,  2.,
     1., -2., -3., -2.,  1.,
    -1.,  1.,  2.,  1., -1.,
]]]], dtype=torch.float32)

labels = torch.tensor([[[[
    0., 1., 1., 1., 0.,
    1., 0., 0., 0., 1.,
    1., 0., 0., 0., 1.,
    1., 0., 0., 0., 1.,
    0., 1., 1., 1., 0.,
]]]], dtype=torch.float32)
In [ ]:
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(0.2240)
Dice Loss: tensor(0.1897)
Focal Loss: tensor(2.3315)
BCE Loss: tensor(4.9871)
  1. Пример с идеальным предсказанием
In [ ]:
labels = torch.tensor([[
    [0,0,0,0,0,0,0,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0]
]], dtype=torch.float)

logits = torch.where(labels==1, torch.tensor(10.0), torch.tensor(-10.0)).unsqueeze(0)
# shape: (1,1,10,10)
labels = labels.unsqueeze(0)
In [ ]:
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(4.1325e-05)
Dice Loss: tensor(8.1062e-05)
Focal Loss: tensor(-3668.2905)
BCE Loss: tensor(0.0046)
  1. Пример с ужасным предсказанием
In [ ]:
labels = torch.tensor([[
    [0,0,0,0,0,0,0,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,1,1,1,1,0,0,0],
    [0,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0]
]], dtype=torch.float)

logits = torch.where(labels==1, torch.tensor(-10.0), torch.tensor(+10.0)).unsqueeze(0)
labels = labels.unsqueeze(0)
In [ ]:
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(22.0742)
Dice Loss: tensor(1.0000)
Focal Loss: tensor(9086.8047)
BCE Loss: tensor(1000.0046)

Т.к. механика всех этих функций потреь отличается, конечно, они не будут одинаковыми, но видно, что для всех примеров они в какой-то степени солидарны

Обучите SegNet на новых лоссах¶

Задание: обучите SegNet на новых лоссах и сравните все три лосса:

  • При каком лоссе модель сходится быстрее?
  • При каком лоссе модель выдает наилучшую метрику?

Напишите развернутый ответ на вопросы.

In [ ]:
model_dice = SegNet()
criterion = DiceLoss(mode='binary')
optimizer = torch.optim.Adam(model_dice.parameters(), lr = 1e-3)
model_dice, score_dice = train_func(model_dice, 50, loaders, optimizer, criterion)
In [ ]:
def focal_loss_wrapper(logits, labels):
    loss = sigmoid_focal_loss(logits, labels, reduction='none')
    return loss.mean()
In [ ]:
model_focal = SegNet()
criterion = focal_loss_wrapper
optimizer = torch.optim.Adam(model_focal.parameters(), lr = 1e-3)
model_focal, score_focal = train_func(model_focal, 50, loaders, optimizer, criterion)
In [ ]:
def ss_loss_wrapper(logits, labels):
  loss = ssloss(logits, labels)
  return loss.mean()
In [ ]:
model_ssl = SegNet()
criterion = ss_loss_wrapper
optimizer = torch.optim.Adam(model_ssl.parameters(), lr = 1e-3)
model_ssl, score_ssl = train_func(model_ssl, 50, loaders, optimizer, criterion)
In [ ]:
def to_float_list(x):
    return [float(v) for v in x]  # работает и для CUDA tensors

score_baseline['val_iou'] = to_float_list(score_baseline['val_iou'])
score_dice['val_iou']      = to_float_list(score_dice['val_iou'])
score_focal['val_iou']     = to_float_list(score_focal['val_iou'])
score_ssl['val_iou']        = to_float_list(score_ssl['val_iou'])
In [ ]:
plt.plot(range(1, 51), score_baseline['val_iou'], label='score baseline')
plt.plot(range(1, 51), score_dice['val_iou'], label='score dice')
plt.plot(range(1, 51), score_focal['val_iou'], label='score focal')
plt.plot(range(1, 51), score_ssl['val_iou'], label='score ssl')


plt.title('IoU на валидационной выборке')
plt.xlabel('epoch')
plt.ylabel('IoU score')
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x78bc651fb410>
No description has been provided for this image
In [ ]:
plt.plot(range(1, 51), score_baseline['val_loss'], label='score baseline')
plt.plot(range(1, 51), score_dice['val_loss'], label='score dice')
plt.plot(range(1, 51), score_focal['val_loss'], label='score focal')
plt.plot(range(1, 51), score_ssl['val_loss'], label='score ssl')

plt.ylim(0, 5)

plt.title('Loss на валидационной выборке')
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x78bc64063200>
No description has been provided for this image
In [ ]:
test_dice = test(model_dice, test_dataloader)
test_focal = test(model_focal, test_dataloader)
test_ssl = test(model_ssl, test_dataloader)
In [ ]:
print("Модель с Dice Loss имеет Iou:", test_dice.cpu().item())
print("Модель с Focal Loss имеет Iou:", test_focal.cpu().item())
print("Модель с SS Loss имеет Iou:", test_ssl.cpu().item())
Модель с Dice Loss имеет Iou: 0.8408982753753662
Модель с Focal Loss имеет Iou: 0.8346421718597412
Модель с SS Loss имеет Iou: 0.681615948677063

Наилучшую сходимость показал focal: он сразу же вышел на минимальную ошибку и единственный стабильно ее сохранял.

Наилучшую метрику на тесте выдал Dice, но стоит заметить, что Focal не сильно отстал, а на валидационной выборке был самым стабильным.

Задание: Новая модель!¶

Модель U-Net [2 балла]¶

U-Net — это архитектура нейронной сети, которая получает изображение и выводит его. Первоначально он был задуман для семантической сегментации (как мы ее будем использовать), но он настолько успешен, что с тех пор используется в других контекстах. Получая на вход медицинское изображение, он выведет изображение в оттенках серого, где интенсивность каждого пикселя зависит от вероятности того, что этот пиксель принадлежит интересующей нас области.

No description has been provided for this image

У нас в архитектуре все так же существует энкодер и декодер, как в SegNet, но отличительной особеностью данной модели являются skip-conenctions, соединяющие части декодера и энкодера. То есть для того чтобы передать на вход декодера тензор, мы конкатенируем симметричный выход с энкодера и выход предыдущего слоя декодера.

  • Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-Net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.

В оригинальной статье авторы не использовали padding внутри модели (это видно по тому, что размеры карты признаков уменьшаются на 2 каждый раз при движении от слоя к слою). При этом размеры входных изображений авторы единоразово увеличили при помощи mirror padding.

В этом домашнем задании вы можете применить альтернативный подход - сохранять размеры карт признаков при помощью padding = 1 во внутренних слоях.

In [ ]:
import torch.nn.functional as F
import torch.nn as nn

Для реализации UNet вы можете написать классы блоков энкодера и декодера отдельно, как мы сделали при реализации SegNet.

In [ ]:
class UNetEncoder(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    x_conv = self.conv(x)
    x_down = self.pool(x_conv)
    return x_conv, x_down
In [ ]:
class UNetDecoder(nn.Module):
  def __init__(self, in_channels, out_channels, padding=1):
    super().__init__()

    self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    self.conv = nn.Sequential(
        nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

  def forward(self, x, skip):
    x = self.up_conv(x)
    x = torch.cat([skip, x], dim=1)
    return self.conv(x)
In [ ]:
class UNet(nn.Module):
    def __init__(self, n_class=1):
        super().__init__()

        self.e0 = UNetEncoder(3, 64)
        self.e1 = UNetEncoder(64, 128)
        self.e2 = UNetEncoder(128, 256)
        self.e3 = UNetEncoder(256, 512)
        self.e4 = UNetEncoder(512, 1024)

        self.d0 = UNetDecoder(1024, 512)
        self.d1 = UNetDecoder(512, 256)
        self.d2 = UNetDecoder(256, 128)
        self.d3 = UNetDecoder(128, 64)

        self.final_conv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        x0_conv, x0_down = self.e0(x)
        x1_conv, x1_down = self.e1(x0_down)
        x2_conv, x2_down = self.e2(x1_down)
        x3_conv, x3_down = self.e3(x2_down)
        x4_conv, x4_down = self.e4(x3_down)

        x = self.d0(x4_conv, x3_conv)
        x = self.d1(x, x2_conv)
        x = self.d2(x, x1_conv)
        x = self.d3(x, x0_conv)

        output = self.final_conv(x)
        return output

Обучите UNet¶

Задание: обучите UNet на всех трех лоссах: BCE, Dice, Focal и сравните результаты с SegNet:

  • Какая модель дает лучшие значения по метрике?
  • Какая модель дает лучшие значения по лоссам?
  • Какая модель обучается быстрее?
  • Сравните визуально результаты SegNet и UNet.

Напишите развернутый ответ на вопросы.

  1. Обучим U-Net на Focal Loss
In [ ]:
unet_model = UNet().to(DEVICE)
criterion = focal_loss_wrapper
optimizer = torch.optim.Adam(unet_model.parameters(), lr = 1e-3)
unet_model, unet_score_focal = train_func(unet_model, 50, loaders, optimizer, criterion)
  1. Обучим U-Net на Dice Loss
In [ ]:
def dice_loss_wrapper(logits, labels):
    loss = dice_loss(logits, labels)
    return loss.mean()
In [ ]:
unet_model_dice = UNet().to(DEVICE)
criterion = dice_loss_wrapper
optimizer = torch.optim.Adam(unet_model_dice.parameters(), lr = 3e-4)
unet_model_dice, unet_score_dice = train_func(unet_model_dice, 50, loaders, optimizer, criterion)
  1. Обучим U-Net на BCE Loss
In [ ]:
unet_model_bce = UNet().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(unet_model_bce.parameters(), lr = 1e-3)
unet_model_bce, unet_score_bce = train_func(unet_model_bce, 50, loaders, optimizer, criterion)
In [ ]:
test_unet_dice = test(unet_model_dice, test_dataloader)
test_unet_focal = test(unet_model, test_dataloader)
test_unet_bce = test(unet_model_bce, test_dataloader)
In [ ]:
print("Модель с Dice Loss имеет Iou:", test_unet_dice.cpu().item())
print("Модель с Focal Loss имеет Iou:", test_unet_focal.cpu().item())
print("Модель с BCE Loss имеет Iou:", test_unet_bce.cpu().item())
Модель с Dice Loss имеет Iou: 0.3579278588294983
Модель с Focal Loss имеет Iou: 0.025996873155236244
Модель с BCE Loss имеет Iou: 0.5349061489105225
In [ ]:
def to_float_list(x):
    return [float(v) for v in x]

unet_score_bce['val_iou'] = to_float_list(unet_score_bce['val_iou'])
unet_score_dice['val_iou']      = to_float_list(unet_score_dice['val_iou'])
unet_score_focal['val_iou']     = to_float_list(unet_score_focal['val_iou'])
In [ ]:
plt.plot(range(1, 51), unet_score_bce['val_iou'], label='score baseline')
plt.plot(range(1, 51), unet_score_dice['val_iou'], label='score dice')
plt.plot(range(1, 51), unet_score_focal['val_iou'], label='score focal')

plt.title('IoU на валидационной выборке для модели U-Net')
plt.xlabel('epoch')
plt.ylabel('IoU score')
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f73d739ae10>
No description has been provided for this image